{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing packages:\n",
      "\t.package(url: \"https://github.com/mxcl/Path.swift\", from: \"0.16.1\")\n",
      "\t\tPath\n",
      "\t.package(url: \"https://github.com/saeta/Just\", from: \"0.7.2\")\n",
      "\t\tJust\n",
      "\t.package(url: \"https://github.com/latenitesoft/NotebookExport\", from: \"0.5.0\")\n",
      "\t\tNotebookExport\n",
      "With SwiftPM flags: []\n",
      "Working in: /tmp/tmpqyuwt4mg/swift-install\n",
      "warning: /home/jupyter/swift/usr/bin/swiftc: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swiftc)\n",
      "/home/jupyter/swift/usr/bin/swift: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swift)\n",
      "warning: /home/jupyter/swift/usr/bin/swiftc: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swiftc)\n",
      "/home/jupyter/swift/usr/bin/swift: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swift)\n",
      "warning: /home/jupyter/swift/usr/bin/swiftc: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swiftc)\n",
      "/home/jupyter/swift/usr/bin/swift: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swift)\n",
      "warning: /home/jupyter/swift/usr/bin/swiftc: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swiftc)\n",
      "/home/jupyter/swift/usr/bin/swift: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swift)\n",
      "/home/jupyter/swift/usr/bin/swiftc: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swiftc)\n",
      "/home/jupyter/swift/usr/bin/swiftc: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swiftc)[1/2] Compiling jupyterInstalledPackages jupyterInstalledPackages.swift\n",
      "/home/jupyter/swift/usr/bin/swift: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swift)\n",
      "[2/3] Merging module jupyterInstalledPackages\n",
      "/home/jupyter/swift/usr/bin/swift: /home/jupyter/anaconda3/lib/libuuid.so.1: no version information available (required by /home/jupyter/swift/usr/bin/swift)\n",
      "Initializing Swift...\n",
      "Installation complete!\n"
     ]
    }
   ],
   "source": [
    "%install-location $cwd/swift-install\n",
    "%install '.package(url: \"https://github.com/mxcl/Path.swift\", from: \"0.16.1\")' Path\n",
    "%install '.package(url: \"https://github.com/saeta/Just\", from: \"0.7.2\")' Just\n",
    "%install '.package(url: \"https://github.com/latenitesoft/NotebookExport\", from: \"0.5.0\")' NotebookExport"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Currently there's a bug in swift-jupyter which requires we define any custom operators here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "// Precedence for `**`: right-associative and binding tighter than multiplication.\n",
    "precedencegroup ExponentiationPrecedence {\n",
    "    associativity: right\n",
    "    higherThan: MultiplicationPrecedence\n",
    "}\n",
    "// Only the operator is declared here; no implementation is provided in this cell.\n",
    "infix operator ** : ExponentiationPrecedence\n",
    "\n",
    "// Precedence for `>|`: left-associative, with no ordering declared relative to other groups.\n",
    "precedencegroup CompositionPrecedence { associativity: left }\n",
    "infix operator >| : CompositionPrecedence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting the MNIST dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use a couple of libraries here:\n",
    "- [Just](https://github.com/JustHTTP/Just) does the equivalent of requests in python\n",
    "- [Path](https://github.com/mxcl/Path.swift) is even better than its python counterpart\n",
    "- Foundation is the base library from Apple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "import Foundation\n",
    "import Just\n",
    "import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will need to gunzip, untar or unzip files we download, so instead of grabbing one library for each, we implement a function that can execute any shell command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public extension String {\n",
    "    /// Runs the executable whose path is this string and returns what it wrote to standard output.\n",
    "    ///\n",
    "    /// - Parameter args: Arguments handed to the executable, in order.\n",
    "    /// - Returns: The captured standard output decoded as UTF-8, or `\"\"` when the process\n",
    "    ///   could not be started or its output is not valid UTF-8.\n",
    "    @discardableResult\n",
    "    func shell(_ args: String...) -> String\n",
    "    {\n",
    "        let (task,pipe) = (Process(),Pipe())\n",
    "        task.executableURL = URL(fileURLWithPath: self)\n",
    "        (task.arguments,task.standardOutput) = (args,pipe)\n",
    "        do    { try task.run() }\n",
    "        catch {\n",
    "            print(\"Unexpected error: \\(error).\")\n",
    "            // Nothing was started, so there is no output to collect. Returning here avoids\n",
    "            // reading a pipe whose write end may never be closed, which could block forever.\n",
    "            return \"\"\n",
    "        }\n",
    "\n",
    "        let data = pipe.fileHandleForReading.readDataToEndOfFile()\n",
    "        return String(data: data, encoding: String.Encoding.utf8) ?? \"\"\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 15M\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  27K Jul 17 22:52 00_load_data.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  43K Jul 17 22:52 00a_intro_and_float.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  43K Jul 17 22:52 01_matmul.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  30K Jul 17 22:52 01a_fastai_layers.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  50K Jul 17 22:52 02_fully_connected.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  23K Jul 17 22:52 02a_why_sqrt5.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  21K Jul 17 22:52 02b_initializing.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  16K Jul 17 22:52 02c_autodiff.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  29K Jul 17 22:52 03_minibatch_training.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  32K Jul 17 22:52 04_callbacks.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  79K Jul 17 22:52 05_anneal.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  57K Jul 17 22:52 05b_early_stopping.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 132K Jul 17 22:52 06_cuda.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  39K Jul 17 22:52 07_batchnorm.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  15K Jul 17 22:52 07b_batchnorm_lesson.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 830K Jul 17 22:52 08_data_block.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 5.2K Jul 17 22:52 08a_heterogeneous_dictionary.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 772K Jul 17 22:52 08b_data_block_opencv.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 526K Jul 17 22:52 08c_data_block-lightlyfunctional.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 438K Jul 17 22:52 08c_data_block_generic.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  42K Jul 17 22:52 09_optimizer.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter  27K Jul 17 22:52 10_mixup_ls.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 433K Jul 17 22:52 11_imagenette.ipynb\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_00_load_data\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_01_matmul\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_01a_fastai_layers\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_02_fully_connected\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_02a_why_sqrt5\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_03_minibatch_training\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_04_callbacks\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_05_anneal\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_05b_early_stopping\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_06_cuda\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_07_batchnorm\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_08_data_block\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_08a_heterogeneous_dictionary\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_08c_data_block_generic\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_09_optimizer\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_10_mixup_ls\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 FastaiNotebook_11_imagenette\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 17 22:52 Imagenette\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 7.3K Jul 16 23:28 Memory leak.ipynb\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 Runnable11\r\n",
      "drwxrwxr-x 6 jupyter jupyter 4.0K Jul 16 20:46 SwiftCV\r\n",
      "drwxrwxr-x 5 jupyter jupyter 4.0K Jul 16 20:46 SwiftSox\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 SwiftVips\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 105K Jul 17 22:52 audio.ipynb\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 303K Jul 17 22:52 c_interop_examples.ipynb\r\n",
      "drwxrwxr-x 3 jupyter jupyter 4.0K Jul 16 20:46 datablock\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 5.1K Jul 16 23:17 memory_leak.swift\r\n",
      "-rw-rw-r-- 1 jupyter jupyter 780K Jul 17 22:52 opencv_integration_example.ipynb\r\n",
      "drwxrwxr-x 4 jupyter jupyter 4.0K Jul 16 20:47 swift-install\r\n",
      "-rw------- 1 jupyter jupyter 9.5M Jul 17 22:50 train-images-idx3-ubyte.gz\r\n",
      "\r\n"
     ]
    }
   ],
   "source": [
    "print(\"/bin/ls\".shell(\"-lh\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To download a file, we use the `Just` library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// Downloads `url` to a local file unless that file already exists.\n",
    "///\n",
    "/// - Parameters:\n",
    "///   - url: Address of the file to fetch.\n",
    "///   - dest: Destination file path; when `nil`, the last `/`-separated component of `url`\n",
    "///     is used as a file name inside the current working directory.\n",
    "///   - force: When `true`, download even if the destination file already exists.\n",
    "public func downloadFile(_ url: String, dest: String? = nil, force: Bool = false) {\n",
    "    let destName: String\n",
    "    if let dest = dest {\n",
    "        destName = dest\n",
    "    } else {\n",
    "        // An empty (or all-slash) URL has no last component; report it instead of trapping.\n",
    "        guard let fileName = url.split(separator: \"/\").last else {\n",
    "            print(\"Can't derive a file name from \\(url)\")\n",
    "            return\n",
    "        }\n",
    "        destName = (Path.cwd/fileName).string\n",
    "    }\n",
    "    let destURL = URL(fileURLWithPath: destName)\n",
    "    // Checked through FileManager rather than `Path(destName)!`: force-unwrapping the failable\n",
    "    // initializer traps whenever it returns nil (presumably for non-absolute paths - confirm\n",
    "    // against Path.swift).\n",
    "    if !force && FileManager.default.fileExists(atPath: destName) { return }\n",
    "\n",
    "    print(\"Downloading \\(url)...\")\n",
    "\n",
    "    let response = Just.get(url)\n",
    "    // Require a successful status so an HTTP error body is never saved as if it were the file\n",
    "    // (a saved error page would also make later calls skip the download).\n",
    "    if response.ok, let cts = response.content {\n",
    "        do    {try cts.write(to: destURL)}\n",
    "        catch {print(\"Can't write to \\(destURL).\\n\\(error)\")}\n",
    "    } else {\n",
    "        print(\"Can't reach \\(url)\")\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "downloadFile(\"https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we will need to read our data and convert it into a `Tensor`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "import TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is generic over the element type on the return value.  We could define two functions like this:\n",
    "\n",
    "```swift\n",
    "func loadMNIST(training: Bool, labels: Bool, path: Path, flat: Bool) -> Tensor<Float> {\n",
    "func loadMNIST(training: Bool, labels: Bool, path: Path, flat: Bool) -> Tensor<Int32> {\n",
    "```\n",
    "\n",
    "but that would be boring.  So we make loadMNIST take a \"type parameter\" `T` which indicates what sort of element type to load a tensor into."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "error: <Cell 9>:15:36: error: cannot invoke 'map' with an argument list of type '(@escaping (Tensor<T>) -> T?)'\n    if labels { return Tensor(data.map(T.init)) }\n                                   ^\n\n<Cell 9>:15:36: note: expected an argument list of type '((Self.Element) throws -> T)'\n    if labels { return Tensor(data.map(T.init)) }\n                                   ^\n\nerror: <Cell 9>:16:36: error: cannot invoke 'map' with an argument list of type '(@escaping (Tensor<T>) -> T?)'\n    else      { return Tensor(data.map(T.init)).reshaped(to: shape)}\n                                   ^\n\n<Cell 9>:16:36: note: expected an argument list of type '((Self.Element) throws -> T)'\n    else      { return Tensor(data.map(T.init)).reshaped(to: shape)}\n                                   ^\n\n"
     ]
    }
   ],
   "source": [
    "func loadMNIST<T>(training: Bool, labels: Bool, path: Path, flat: Bool) -> Tensor<T> {\n",
    "    let split = training ? \"train\" : \"t10k\"\n",
    "    let kind = labels ? \"labels\" : \"images\"\n",
    "    let batch = training ? 60000 : 10000\n",
    "    let shape: TensorShape = labels ? [batch] : (flat ? [batch, 784] : [batch, 28, 28])\n",
    "    let dropK = labels ? 8 : 16\n",
    "    let baseUrl = \"https://storage.googleapis.com/cvdf-datasets/mnist/\"\n",
    "    let fname = split + \"-\" + kind + \"-idx\\(labels ? 1 : 3)-ubyte\"\n",
    "    let file = path/fname\n",
    "    if !file.exists {\n",
    "        downloadFile(\"\\(baseUrl)\\(fname).gz\", dest:(path/\"\\(fname).gz\").string)\n",
    "        \"/bin/gunzip\".shell(\"-fq\", (path/\"\\(fname).gz\").string)\n",
    "    }\n",
    "    let data = try! Data(contentsOf: URL(fileURLWithPath: file.string)).dropFirst(dropK)\n",
    "    if labels { return Tensor(data.map(T.init)) }\n",
    "    else      { return Tensor(data.map(T.init)).reshaped(to: shape)}\n",
    "}\n",
    "// NOTE: this version is expected to fail to compile: `T` is unconstrained, so the compiler\n",
    "// rejects `data.map(T.init)` (see the saved error output of this cell).\n",
    "// This note sits below the function so the line numbers in that saved error still match."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But this doesn't work because S4TF can't just put any type of data inside a `Tensor`. We have to tell it that this type:\n",
    "- is a type that TF can understand and deal with\n",
    "- is a type that can be applied to the data we read in the byte format\n",
    "\n",
    "We do this by defining a protocol called `ConvertibleFromByte` that inherits from `TensorFlowScalar`. That takes care of the first requirement. The second requirement is dealt with by asking for an `init` method that takes `UInt8`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// A `TensorFlowScalar` type that can be created from a single `UInt8`.\n",
    "protocol ConvertibleFromByte: TensorFlowScalar {\n",
    "    /// Creates a value from the byte `d`.\n",
    "    init(_ d:UInt8)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we need to say that `Float` and `Int32` conform to that protocol. They already have the right initializer so we don't have to code anything."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "// `Float` and `Int32` are declared to conform as they are: the empty bodies add no members.\n",
    "extension Float : ConvertibleFromByte {}\n",
    "extension Int32 : ConvertibleFromByte {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we write a convenience method for all types that conform to the `ConvertibleFromByte` protocol, that will convert some raw data to a `Tensor` of that type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "extension Data {\n",
    "    /// Converts every byte of this data to `T` and wraps the resulting array in a `Tensor`.\n",
    "    func asTensor<T:ConvertibleFromByte>() -> Tensor<T> {\n",
    "        let scalars: [T] = self.map { byte in T(byte) }\n",
    "        return Tensor<T>(scalars)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we can write a generic `loadMNIST` function that can return tensors of `Float` or `Int32`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// Loads one MNIST file (the images or the labels of one split) found under `path` into a\n",
    "/// tensor, downloading and gunzipping it first when the uncompressed file is missing.\n",
    "///\n",
    "/// - Parameters:\n",
    "///   - training: `true` selects the `train` files (60000 items), `false` the `t10k` files (10000).\n",
    "///   - labels: `true` loads the labels file, `false` the images file.\n",
    "///   - path: Directory containing (or receiving) the file.\n",
    "///   - flat: For images only: `[n, 784]` when `true`, `[n, 28, 28]` otherwise.\n",
    "func loadMNIST<T: ConvertibleFromByte>\n",
    "            (training: Bool, labels: Bool, path: Path, flat: Bool) -> Tensor<T> {\n",
    "    let itemCount = training ? 60000 : 10000\n",
    "    // Number of leading bytes dropped before the payload is converted.\n",
    "    let headerLength = labels ? 8 : 16\n",
    "    let baseUrl = \"https://storage.googleapis.com/cvdf-datasets/mnist/\"\n",
    "    let stem = (training ? \"train\" : \"t10k\") + \"-\" + (labels ? \"labels\" : \"images\")\n",
    "    let fname = stem + \"-idx\\(labels ? 1 : 3)-ubyte\"\n",
    "    let file = path/fname\n",
    "    if !file.exists {\n",
    "        let archive = (path/\"\\(fname).gz\").string\n",
    "        downloadFile(\"\\(baseUrl)\\(fname).gz\", dest:archive)\n",
    "        \"/bin/gunzip\".shell(\"-fq\", archive)\n",
    "    }\n",
    "    let payload = try! Data(contentsOf: URL(fileURLWithPath: file.string)).dropFirst(headerLength)\n",
    "    let values: Tensor<T> = payload.asTensor()\n",
    "    // Labels are returned as converted; only images are reshaped.\n",
    "    guard !labels else { return values }\n",
    "    let imageShape: TensorShape = flat ? [itemCount, 784] : [itemCount, 28, 28]\n",
    "    return values.reshaped(to: imageShape)\n",
    "}\n",
    "\n",
    "/// Loads the four MNIST tensors kept under `path`, creating that directory first if needed.\n",
    "///\n",
    "/// - Parameters:\n",
    "///   - path: Directory containing (or receiving) the MNIST files.\n",
    "///   - flat: Forwarded to the per-file loader to choose the image tensor layout.\n",
    "/// - Returns: `(training images, training labels, non-training images, non-training labels)`;\n",
    "///   both image tensors are divided by 255.\n",
    "public func loadMNIST(path:Path, flat:Bool = false)\n",
    "        -> (Tensor<Float>, Tensor<Int32>, Tensor<Float>, Tensor<Int32>) {\n",
    "    try! path.mkdir(.p)\n",
    "    let trainImages: Tensor<Float> = loadMNIST(training: true,  labels: false, path: path, flat: flat)\n",
    "    let trainLabels: Tensor<Int32> = loadMNIST(training: true,  labels: true,  path: path, flat: flat)\n",
    "    let otherImages: Tensor<Float> = loadMNIST(training: false, labels: false, path: path, flat: flat)\n",
    "    let otherLabels: Tensor<Int32> = loadMNIST(training: false, labels: true,  path: path, flat: flat)\n",
    "    return (trainImages / 255.0, trainLabels, otherImages / 255.0, otherLabels)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will store mnist in this folder so that we don't download it each time we run a notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// Directory under the user's home (`.fastai/data/mnist_tst`) used to store the MNIST files.\n",
    "public let mnistPath = Path.home/\".fastai\"/\"data\"/\"mnist_tst\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default returns mnist in the image format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "▿ [60000, 28, 28]\n",
       "  ▿ dimensions : 3 elements\n",
       "    - 0 : 60000\n",
       "    - 1 : 28\n",
       "    - 2 : 28\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "let (xTrain, yTrain, xValid, yValid) = loadMNIST(path: mnistPath)\n",
    "xTrain.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also ask for it in its flattened form:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "▿ [60000, 784]\n",
       "  ▿ dimensions : 2 elements\n",
       "    - 0 : 60000\n",
       "    - 1 : 784\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "let (xTrain, yTrain, xValid, yValid) = loadMNIST(path: mnistPath, flat: true)\n",
    "xTrain.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Timing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is our time function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export \n",
    "import Dispatch\n",
    "\n",
    "/// ⏰ Runs `f` `repeating` times and prints the average, minimum and maximum duration\n",
    "/// of a run in milliseconds.\n",
    "///\n",
    "/// - Parameters:\n",
    "///   - repeating: Number of timed runs; nothing is run or printed when it is not positive.\n",
    "///   - f: The work to time.\n",
    "public func time(repeating: Int = 1, _ f: () -> ()) {\n",
    "    if repeating <= 0 { return }\n",
    "\n",
    "    // Warmup: one extra, untimed call when more than one repetition is requested.\n",
    "    if repeating > 1 { f() }\n",
    "\n",
    "    let durations: [Double] = (0..<repeating).map { _ in\n",
    "        let began = DispatchTime.now().uptimeNanoseconds\n",
    "        f()\n",
    "        let finished = DispatchTime.now().uptimeNanoseconds\n",
    "        // Nanoseconds to milliseconds.\n",
    "        return Double(finished - began) / 1e6\n",
    "    }\n",
    "    let average = durations.reduce(0.0, +) / Double(durations.count)\n",
    "    let fastest = durations.min()!\n",
    "    let slowest = durations.max()!\n",
    "    print(\"average: \\(average) ms,   \" +\n",
    "          \"min: \\(fastest) ms,   \" +\n",
    "          \"max: \\(slowest) ms\")\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average: 485.15768909999997 ms,   min: 474.723727 ms,   max: 501.671377 ms\r\n"
     ]
    }
   ],
   "source": [
    "time(repeating: 10) {\n",
    "    _ = loadMNIST(training: false, labels: false, path: mnistPath, flat: false) as Tensor<Float>\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Searching for a specific pattern with a regular expression isn't easy in swift. The good thing is that with an extension, we can make it easy for us!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "public extension String {\n",
    "    /// Returns the range of the first match of the regular expression `pat` in this string,\n",
    "    /// or `nil` when nothing matches.\n",
    "    func findFirst(pat: String) -> Range<String.Index>? {\n",
    "        let firstMatch = self.range(of: pat, options: .regularExpression)\n",
    "        return firstMatch\n",
    "    }\n",
    "    /// Tells whether the regular expression `pat` matches anywhere in this string.\n",
    "    func hasMatch(pat: String) -> Bool {\n",
    "        guard findFirst(pat: pat) != nil else { return false }\n",
    "        return true\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Foundation library isn't always the most convenient to use... This is how the first lines of the function below would be written with it.\n",
    "\n",
    "```swift\n",
    "let url_fname = URL(fileURLWithPath: fname)\n",
    "let last = fname.lastPathComponent\n",
    "let out_fname = (url_fname.deletingLastPathComponent().appendingPathComponent(\"FastaiNotebooks\", isDirectory: true)\n",
    "     .appendingPathComponent(\"Sources\", isDirectory: true)\n",
    "     .appendingPathComponent(\"FastaiNotebooks\", isDirectory: true).appendingPathComponent(last)\n",
    "     .deletingPathExtension().appendingPathExtension(\"swift\"))\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function parses the underlying json behind a notebook to keep the code in the cells marked with `//export`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// Writes the code of every exported cell of the notebook `fname` to\n",
    "/// `FastaiNotebooks/Sources/FastaiNotebooks/<notebook name>.swift` next to the notebook.\n",
    "///\n",
    "/// A cell is exported when its first source line holds nothing but an `//export` comment\n",
    "/// (surrounding whitespace allowed); that marker line itself is not copied.\n",
    "/// Failures are printed rather than thrown, with read and write problems reported separately.\n",
    "///\n",
    "/// - Parameter fname: Path of the notebook file to convert.\n",
    "public func notebookToScript(fname: Path){\n",
    "    let newname = fname.basename(dropExtension: true)+\".swift\"\n",
    "    let url = fname.parent/\"FastaiNotebooks/Sources/FastaiNotebooks\"/newname\n",
    "\n",
    "    let cells: [[String:Any]]\n",
    "    do {\n",
    "        let data = try Data(contentsOf: fname.url)\n",
    "        let jsonData = try JSONSerialization.jsonObject(with: data, options: .allowFragments)\n",
    "        // Conditional casts: JSON of an unexpected shape is reported instead of trapping.\n",
    "        guard let notebook = jsonData as? [String: Any],\n",
    "              let parsedCells = notebook[\"cells\"] as? [[String:Any]] else {\n",
    "            print(\"Can't read the content of \\(fname)\")\n",
    "            return\n",
    "        }\n",
    "        cells = parsedCells\n",
    "    } catch {\n",
    "        print(\"Can't read the content of \\(fname)\\n\\(error)\")\n",
    "        return\n",
    "    }\n",
    "\n",
    "    // `basename()` is used for the file name, as elsewhere in this notebook.\n",
    "    var module = \"\"\"\n",
    "/*\n",
    "THIS FILE WAS AUTOGENERATED! DO NOT EDIT!\n",
    "file to edit: \\(fname.basename())\n",
    "\n",
    "*/\n",
    "        \n",
    "\"\"\"\n",
    "    for cell in cells {\n",
    "        if let source = cell[\"source\"] as? [String], !source.isEmpty, \n",
    "               source[0].hasMatch(pat: #\"^\\s*//\\s*export\\s*$\"#) {\n",
    "            module.append(\"\\n\" + source[1...].joined() + \"\\n\")\n",
    "        }\n",
    "    }\n",
    "\n",
    "    do {\n",
    "        // Create the output directory first; writing into a missing directory fails.\n",
    "        try url.parent.mkdir(.p)\n",
    "        try module.write(to: url, encoding: .utf8)\n",
    "    } catch {\n",
    "        print(\"Can't write to \\(url)\\n\\(error)\")\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And this will do all the notebooks in a given folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "/// Runs `notebookToScript` on every file directly inside `path` whose name matches\n",
    "/// `^\\d*_.*ipynb$`, printing each entry before it is converted.\n",
    "///\n",
    "/// - Parameter path: Directory whose entries are examined.\n",
    "public func exportNotebooks(_ path: Path) {\n",
    "    let notebookPattern = #\"^\\d*_.*ipynb$\"#\n",
    "    for entry in try! path.ls() {\n",
    "        // Only regular files are considered.\n",
    "        guard entry.kind == Entry.Kind.file else { continue }\n",
    "        guard entry.path.basename().hasMatch(pat: notebookPattern) else { continue }\n",
    "        print(\"Converting \\(entry)\")\n",
    "        notebookToScript(fname: entry.path)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Can't read the content of /home/jupyter/notebooks/swift/00_load_data.ipynb\r\n"
     ]
    }
   ],
   "source": [
    "notebookToScript(fname: Path.cwd/\"00_load_data.ipynb\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But now that we implemented it from scratch we're allowed to use it as a package ;). NotebookExport has been written by pcuenq\n",
    "and will make our lives easier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "success\r\n"
     ]
    }
   ],
   "source": [
    "import NotebookExport\n",
    "let exporter = NotebookExport(Path.cwd/\"00_load_data.ipynb\")\n",
    "print(exporter.export(usingPrefix: \"FastaiNotebook_\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Swift",
   "language": "swift",
   "name": "swift"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
